# =============================================================================
# Contains several functions that are used across projects and improved periodically.
# =============================================================================
from __main__ import *

import copy
import functools
import io
import operator
import os
import re
import time

import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from tqdm import tqdm

# =============================================================================
# Custom Functions
def downsize(df, types):
    """Cast columns of ``df`` according to the type specs in ``types``.

    Parameters
    ----------
    df : pandas.DataFrame
        Columns are reassigned on this object, and the same object is returned.
    types : mapping
        Column name -> type spec. Iterating it must yield column names
        (e.g. a dict). Only columns present in both ``df`` and ``types`` are
        processed, and specs for which ``isnan2`` is true are skipped.
        Recognised specs:

        - ``'float'``: numeric, downcast to the smallest float dtype.
        - ``'float_nd'``: numeric, no downcast.
        - ``'integer'``: remove every ``'-'`` character, then convert to
          numeric, downcast to the smallest integer dtype the values allow.
        - ``'integer_nd'``: remove every ``'-'``, then numeric, no downcast.
        - ``'cat'``: ``category`` dtype.
        - ``'date'``: parsed with format ``'%Y-%m-%d'``.
        - ``'date_comp'``: ``pd.to_datetime`` with default parsing
          (errors are raised, not coerced).

        Any other spec leaves the column unchanged. For the numeric specs and
        ``'date'``, unparseable values become NaN / NaT.

    Returns
    -------
    pandas.DataFrame
        The same ``df`` object.
    """
    for var in set(df) & set(types):
        spec = types[var]
        if isnan2(spec):
            continue
        if spec == 'float':
            df[var] = pd.to_numeric(df[var], downcast='float', errors='coerce')
        elif spec == 'float_nd':
            df[var] = pd.to_numeric(df[var], errors='coerce')
        elif spec in ('integer', 'integer_nd'):
            # Strip dashes before parsing. NOTE(review): this also removes the
            # minus sign of negative numbers stored as strings - confirm intended.
            # Reassign the result instead of calling replace(inplace=True) on
            # df[var]: that chained in-place call does not update df under
            # pandas Copy-on-Write.
            cleaned = df[var].replace(regex=True, to_replace=r'-', value=r'')
            downcast = 'integer' if spec == 'integer' else None
            df[var] = pd.to_numeric(cleaned, downcast=downcast, errors='coerce')
        elif spec == 'cat':
            df[var] = df[var].astype('category')
        elif spec == 'date':
            df[var] = pd.to_datetime(df[var], format='%Y-%m-%d', errors='coerce')
        elif spec == 'date_comp':
            df[var] = pd.to_datetime(df[var])
    return df


def resample(df, n, nunique=100000000, method='random_flatdensity'):
    """Draw ``n`` synthetic rows by sampling every column of ``df`` independently.

    Parameters
    ----------
    df : pandas.DataFrame
        Source data. Columns are sampled independently of each other, so the
        joint distribution across columns is NOT preserved.
    n : int
        Number of rows to generate.
    nunique : int
        Exclusive upper bound of the random integers (drawn from
        ``[1, nunique)``). These are reduced modulo the number of candidate
        values, so ``nunique`` should be much larger than that number to keep
        the draw close to uniform.
    method : {'random_flatdensity', 'random_preservedensity'}
        ``'random_preservedensity'`` draws from the rows of each column, so
        frequent values stay frequent. ``'random_flatdensity'`` draws from the
        distinct values of each column, so every distinct value is about
        equally likely.

    Returns
    -------
    pandas.DataFrame
        ``n`` rows with the same column names as ``df`` and a RangeIndex.

    Raises
    ------
    ValueError
        If ``method`` is not one of the two supported values.
    """
    if method not in ('random_preservedensity', 'random_flatdensity'):
        raise ValueError(
            f"unknown method {method!r}; expected 'random_preservedensity' "
            "or 'random_flatdensity'"
        )
    draws = pd.DataFrame(
        np.random.randint(1, nunique, (n, df.shape[1])), columns=list(df)
    )
    for v in list(df):
        if method == 'random_preservedensity':
            pool = df[v]
        else:
            pool = df[v].drop_duplicates()
        # The draws are positions, so give the pool a 0..k-1 index before
        # mapping. Otherwise a non-RangeIndex df would map to NaN.
        pool = pool.reset_index(drop=True)
        draws[v] = draws[v].mod(pool.shape[0]).map(pool)
    return draws


# Carry non-missing values forward (and optionally backward) within groups using a custom fill method.
def fill_eff(df, vfill, gp, sort, limit=1000, par=False, par_ng=20, n_jobs=10):
    '''
    Fill missing values of the ``vfill`` columns within groups ``gp`` by
    propagating the last non-missing value along a sort order.

    Parameters
    ----------
    df : pandas.DataFrame
        On the sequential path ``df`` is modified in place: it is sorted,
        the filled columns are overwritten, and the index is reset. The same
        object is returned. The parallel path returns a new frame.
    vfill : list of str
        Columns to fill.
    gp : str or list/tuple of column labels
        Group columns. A single label is treated as ``[gp]``.
    sort : dict
        ``'svars'``: list of sort columns. Values are propagated from the
        previous row in this order, so ``svars`` should start with the ``gp``
        columns to keep each group contiguous.
        ``'sort1'``: list of ascending flags for the first pass.
        ``'sort2'`` (optional): list of ascending flags for a second pass.
        For example, reversing the time variable turns that pass into a
        backward fill. The returned frame is left in the order of the last
        pass.
    limit : int
        Maximum number of consecutive missing values to fill.
    par : bool
        If true, split the distinct values of ``gp[0]`` into ``par_ng`` chunks
        and process them with joblib.
    par_ng : int
        Number of chunks for the parallel path.
    n_jobs : int
        Number of joblib workers for the parallel path (previously hard-coded
        to 10).

    A row is only filled once its group has had at least one value that is
    neither missing nor the empty string.
    '''
    if isinstance(gp, (list, tuple)):
        gp = list(gp)
    else:
        gp = [gp]
    if par:
        print("parallel approach")
        svar = gp[0]
        chunks = split(df[svar].unique(), par_ng)
        results = Parallel(n_jobs=n_jobs)(
            delayed(fill_eff)(df=df[df[svar].isin(tgt)].copy(), vfill=vfill,
                              gp=gp, sort=sort, limit=limit)
            for tgt in tqdm(chunks)
        )
        return pd.concat(results, ignore_index=True, sort=False)
    return _fill_eff_sequential(df, vfill, gp, sort, limit)


def _fill_eff_sequential(df, vfill, gp, sort, limit):
    '''In-place, single-process implementation behind ``fill_eff``.'''
    df.sort_values(by=sort['svars'], ascending=sort['sort1'], inplace=True)
    # Integer group id used as the index, so groupby(level=0) is cheap.
    df['temporaryone'] = df.groupby(gp).ngroup()
    # Row position after the first sort; breaks ties in the later sorts.
    df['temporarytwo'] = range(1, len(df) + 1)
    df.set_index(keys='temporaryone', inplace=True)
    # Work on categoricals as plain objects and restore the dtype afterwards.
    # astype(object) keeps NaN as missing. astype(str) would turn it into the
    # literal 'nan', which is then never filled.
    cats = list(df.select_dtypes(include='category').columns)
    for v in cats:
        df[v] = df[v].astype(object)
    tiebreaker = {'sort1': [True], 'sort2': [False]}
    for s in ('sort1', 'sort2'):
        if s not in sort:
            continue
        df.sort_values(by=sort['svars'] + ['temporarytwo'],
                       ascending=sort[s] + tiebreaker[s], inplace=True)
        tt = 0
        for i, v in enumerate(vfill):
            t = time.time()
            print(s + ': variable ' + v + ' progress is ' + str(i) + '/' + str(len(vfill)) + ' time taken for last loop is ' + str(np.round(tt, 0)))
            has_value = df[v].notna() & df[v].ne('')
            # True from the group's first real value onward. Rows before it keep
            # their original value, so the ungrouped ffill below cannot leak a
            # value in from the previous group.
            seen = has_value.astype(int).groupby(level=0).cumsum() > 0
            df[v] = np.where(seen, df[v].ffill(limit=limit), df[v])
            tt = time.time() - t
    for v in cats:
        df[v] = df[v].astype("category")
    df.reset_index(drop=True, inplace=True)
    df.drop(columns=['temporarytwo'], inplace=True)
    return df


def isnan2(x):
    """Return ``np.isnan(x)``, or ``False`` when NumPy cannot test ``x``.

    Unlike ``np.isnan``, inputs such as strings or ``None`` do not raise a
    ``TypeError``; they report ``False``. Array-like numeric input gives an
    elementwise boolean array.
    """
    try:
        return np.isnan(x)
    except TypeError:
        return False

def split(a, n):
    """Distribute the elements of ``a`` round-robin into ``n`` lists.

    If ``a`` supports ``.copy()`` and ``.sort()`` (e.g. a list or a NumPy
    array), a sorted copy is used. Otherwise, including when the elements
    cannot be compared, the original iteration order is kept. ``a`` itself is
    never modified. Element ``i`` of that sequence goes to list ``i % n``, so
    list lengths differ by at most one.

    Raises
    ------
    ValueError
        If ``n`` is less than 1.
    """
    if n < 1:
        raise ValueError("n must be at least 1, got " + repr(n))
    try:
        ordered = a.copy()
        ordered.sort()
    except Exception:
        # Deliberate best-effort: unsortable or sort-less inputs keep their order.
        ordered = a
    shell = [[] for _ in range(n)]
    for i, v in enumerate(ordered):
        shell[i % n].append(v)
    return shell



def ldiff(a, b, method='fast'):
    """Return the elements of ``a`` that are not in ``b``.

    Parameters
    ----------
    method : {'fast', 'slow'}
        ``'fast'`` uses a set difference. Order and duplicates of ``a`` are
        not preserved, and it falls back to ``'slow'`` when elements are
        unhashable. ``'slow'`` keeps the order and duplicates of ``a``; it is
        O(len(a) * len(b)) when ``b`` is a list.

    Raises
    ------
    ValueError
        For an unknown ``method``. This used to return None silently.
    """
    if method == 'fast':
        try:
            return list(set(a) - set(b))
        except TypeError:
            # Unhashable elements: fall back to the order-preserving scan.
            pass
    elif method != 'slow':
        raise ValueError(f"unknown method {method!r}; expected 'fast' or 'slow'")
    return [v for v in a if v not in b]

def lint(a, b, method='fast'):
    """Return the elements of ``a`` that are also in ``b``.

    Parameters
    ----------
    method : {'fast', 'slow'}
        ``'fast'`` uses a set intersection. Order and duplicates are not
        preserved, and it falls back to ``'slow'`` when elements are
        unhashable. ``'slow'`` keeps the order and duplicates of ``a``.

    Raises
    ------
    ValueError
        For an unknown ``method``. This used to return None silently.
    """
    if method == 'fast':
        try:
            return list(set(a).intersection(set(b)))
        except TypeError:
            # Unhashable elements: fall back to the order-preserving scan.
            pass
    elif method != 'slow':
        raise ValueError(f"unknown method {method!r}; expected 'fast' or 'slow'")
    return [v for v in a if v in b]

def lunion(a, b, method='fast'):
    """Return the distinct elements found in ``a`` or ``b``.

    ``a`` and ``b`` may be any iterables (lists, tuples, NumPy arrays, ...).
    They no longer need to support ``+``, which failed for list + tuple and
    added NumPy arrays elementwise.

    Parameters
    ----------
    method : str
        ``'slow'`` returns an order-preserving union (first occurrence wins,
        ``a`` before ``b``). Any other value, including the default
        ``'fast'``, uses a set union whose order is arbitrary. That path falls
        back to the order-preserving scan when elements are unhashable.
    """
    if method != 'slow':
        try:
            return list(set(a) | set(b))
        except TypeError:
            # Unhashable elements: fall back to the order-preserving scan.
            pass
    out = []
    for seq in (a, b):
        for v in seq:
            if v not in out:
                out.append(v)
    return out


def add_desc(df, desc_dict):
    """Add constant-valued columns to ``df`` (in place) and return it.

    Parameters
    ----------
    df : pandas.DataFrame
        Receives one new or overwritten column per key of ``desc_dict``.
    desc_dict : dict
        Column name -> value. A non-list value is broadcast as a scalar. A
        list value is stored as the whole list in every row. Each row gets
        its own deep copy, so element types and nesting are kept as-is
        (``np.tile`` used to coerce ``[1, 'a']`` to ``['1', 'a']`` and
        mangled nested lists).
    """
    for k, value in desc_dict.items():
        if isinstance(value, list):
            df[k] = [copy.deepcopy(value) for _ in range(len(df))]
        else:
            df[k] = value
    return df



# Clip a series to quantile or absolute bounds. Originally used to limit the impact of relationships between two bankers at the same bank, which raise the syndicate relationship strength without much real impact.
def clip_series(s, lower=0, upper=1, ctype='quantile'):
    """Clip the pandas Series ``s`` to ``[lower, upper]``.

    Parameters
    ----------
    lower, upper : float
        With ``ctype='quantile'`` these are quantile levels in [0, 1] and are
        converted to values via ``s.quantile``. The defaults 0 and 1 therefore
        leave ``s`` unchanged. With ``ctype='absolute'`` they are the bounds
        themselves.
    ctype : {'quantile', 'absolute'}

    Raises
    ------
    ValueError
        For an unknown ``ctype``. This used to fail with UnboundLocalError.
    """
    if ctype == 'quantile':
        lower, upper = s.quantile(lower), s.quantile(upper)
    elif ctype != 'absolute':
        raise ValueError(f"unknown ctype {ctype!r}; expected 'quantile' or 'absolute'")
    return s.clip(lower=lower, upper=upper)


#Short function to flatten arbitrarily nested lists (https://stackoverflow.com/questions/952914/how-to-make-a-flat-list-out-of-list-of-lists)
#   Warning, this requires the same number of levels for all sublists.
def functools_reduce_iconcat(a, levels=1):
    """Remove ``levels`` levels of nesting from ``a`` by repeated concatenation.

    Each level concatenates the sub-sequences of the current result into one
    new list using ``operator.iconcat``. ``levels=0`` returns ``a`` itself.
    Every sub-sequence at a given level must be a sequence, so all branches
    need the same depth.
    """
    flat = a
    for _ in range(levels):
        merged = []
        for chunk in flat:
            merged = operator.iconcat(merged, chunk)
        flat = merged
    return flat

def flatten(l):
    """Flatten exactly one level of nesting, e.g. ``[[1, 2], [3]] -> [1, 2, 3]``."""
    flat = []
    for sublist in l:
        flat.extend(sublist)
    return flat


def bina(v):
    """Set strictly positive entries of ``v`` to 1 and keep all others unchanged.

    Zero, negative and NaN entries are passed through. The result comes from
    ``np.where``, so array/Series input yields a NumPy array.
    """
    is_positive = v > 0
    return np.where(is_positive, is_positive * 1, v)



def pkl2parquet(file, drop_bad=False, drop_cols=None):
    """Convert a pickled DataFrame to a parquet file.

    The output path is ``file`` with a trailing ``.pkl`` removed and
    ``.parquet`` appended. ``'a.pkl'`` becomes ``'a.parquet'``; a path without
    a ``.pkl`` suffix gets ``.parquet`` appended, so the input file is never
    overwritten.

    Parameters
    ----------
    file : str or path-like
        Pickle file to read.
    drop_bad : bool
        If the full write fails and this is true, write the file without the
        columns that fail the per-column probe.
    drop_cols : iterable, optional
        Columns to exclude from every write, including the fallback one.

    Returns
    -------
    str or list
        ``"success"`` if the first write succeeded. Otherwise, the list of
        columns that could not be serialised; the parquet file is then only
        written when ``drop_bad`` is true.
    """
    # SECURITY: pd.read_pickle can execute arbitrary code; only use on trusted files.
    temp = pd.read_pickle(file)
    out_file = re.sub(r'\.pkl$', '', str(file)) + '.parquet'
    excluded = set(drop_cols or [])
    cols = [c for c in temp.columns if c not in excluded]  # keeps column order
    try:
        temp[cols].to_parquet(out_file)
        return "success"
    except Exception:
        # Fall through to per-column probing to find the offending columns.
        pass
    bad_col = []
    sample = temp.iloc[:10000]
    for v in tqdm(cols):
        print(v)
        try:
            # Probe into memory so failed/partial probes never clobber out_file.
            sample[[v]].to_parquet(io.BytesIO())
        except Exception:
            bad_col += [v]
    if drop_bad:
        bad = set(bad_col)
        temp[[c for c in cols if c not in bad]].to_parquet(out_file)
    return bad_col


import itertools
def bal_panel(index):
    """
    Build a balanced panel: one row per element of the Cartesian product of
    the lists in ``index`` (a dict of column name -> values), with one column
    per key, in key order.
    """
    combos = list(itertools.product(*index.values()))
    return pd.DataFrame(combos, columns=index.keys())


def getListOfFiles(dirName):
    """Return the paths of all non-directory entries below ``dirName``, recursively.

    Entries are visited in ``os.listdir`` order, descending depth-first into
    each subdirectory when it is encountered. ``os.path.isdir`` follows
    symlinks, so symlinked directories are descended into as well.
    """
    collected = []
    for name in os.listdir(dirName):
        path = os.path.join(dirName, name)
        if not os.path.isdir(path):
            collected.append(path)
            continue
        collected.extend(getListOfFiles(path))
    return collected